# Copyright 2025 AlQuraishi Laboratory
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Transition (feed-forward) layers. Includes the abstract Transition base class,
ReLUTransitionLayer, ReLUTransition, SwiGLUTransition, ConditionedTransitionBlock,
and StructureModuleTransition.
"""

from abc import ABC, abstractmethod

import torch
import torch.nn as nn
from ml_collections import ConfigDict

import openfold3.core.config.default_linear_init_config as lin_init
from openfold3.core.model.primitives import AdaLN, LayerNorm, Linear, SwiGLU
from openfold3.core.utils.checkpointing import checkpoint_section
from openfold3.core.utils.chunk_utils import chunk_layer


class Transition(nn.Module, ABC):
    @abstractmethod
    def _transition(self, x: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        pass

    @torch.jit.ignore
    def _chunk(
        self,
        x: torch.Tensor,
        mask: torch.Tensor,
        chunk_size: int,
    ) -> torch.Tensor:
        return chunk_layer(
            self._transition,
            {"x": x, "mask": mask},
            chunk_size=chunk_size,
            no_batch_dims=len(x.shape[:-2]),
        )

    # TODO: Make this a more general function
    def _low_mem_ckpt_chunk(
        self,
        x: torch.Tensor,
        mask: torch.Tensor,
        chunk_size: int,
        chunk_dim: int = -3,
    ) -> torch.Tensor:
        """
        Chunk and checkpoint the transition layer during training. Necessary for
        extreme cases where the backward pass of this module is too memory intensive.

        Args:
            x:
                [*, N, C_in] Input activation
            mask:
                [*, N] Input mask
            chunk_size:
                Chunk size over chunk dim
            chunk_dim:
                Dimension to chunk over

        Returns:
            [*, N, C_in] Loss for each sample
        """
        chunks = []
        for i in range(0, x.shape[chunk_dim], chunk_size):
            # Slice x and mask along chunk_dim and checkpoint each chunk so that
            # its intermediate activations are recomputed during the backward pass.
            # mask has already been unsqueezed to x's ndim in forward().
            idx = [slice(None)] * x.dim()
            idx[chunk_dim] = slice(i, i + chunk_size)
            idx = tuple(idx)
            l_chunk = checkpoint_section(
                fn=self._transition,
                args=(x[idx], mask[idx]),
                apply_ckpt=True,
                use_reentrant=False,
            )
            chunks.append(l_chunk)
        return torch.cat(chunks, dim=chunk_dim)

    def forward(
        self,
        x: torch.Tensor,
        mask: torch.Tensor | None = None,
        chunk_size: int | None = None,
        ckpt_chunk_size: int | None = None,
    ) -> torch.Tensor:
        """
        Args:
            x:
                [*, N, C_in] Input activation
            mask:
                [*, N] Input mask
            chunk_size:
                Chunk size for chunking the input tensor
            ckpt_chunk_size:
                Chunk size for chunked activation checkpointing in the
                transition layer; takes precedence over chunk_size when both
                are provided
        Returns:
            x:
                [*, N, C_in] Activation update
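
        Example (minimal sketch using a concrete subclass; the sizes and chunk
        sizes below are illustrative placeholders, not config defaults):
            >>> transition = SwiGLUTransition(c_in=128, n=4)
            >>> z = torch.randn(1, 32, 32, 128)
            >>> out = transition(z, chunk_size=4)        # chunked inference
            >>> out = transition(z, ckpt_chunk_size=4)   # checkpointed chunking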
        """
        if mask is None:
            mask = x.new_ones(x.shape[:-1])

        # [*, N, 1]
        mask = mask.unsqueeze(-1)

        if ckpt_chunk_size is not None:
            x = self._low_mem_ckpt_chunk(
                x=x,
                mask=mask,
                chunk_size=ckpt_chunk_size,
            )
        elif chunk_size is not None:
            x = self._chunk(x=x, mask=mask, chunk_size=chunk_size)
        else:
            x = self._transition(x=x, mask=mask)

        return x


class ReLUTransitionLayer(nn.Module):
    """
    Stack of Linear + ReLU layers followed by a final linear projection; the
    inner feed-forward network used by ReLUTransition and StructureModuleTransition.
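
    Example (minimal sketch; the sizes below are illustrative placeholders,
    not config defaults):
        >>> layer = ReLUTransitionLayer(num_relu_layers=1, c_in=64, n=4)
        >>> x = torch.randn(2, 10, 64)
        >>> mask = x.new_ones(2, 10, 1)
        >>> out = layer(x, mask)  # [2, 10, 64]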
    """

    def __init__(
        self, num_relu_layers, c_in, n, linear_init_params=lin_init.relu_transition_init
    ):
        """
        Args:
            num_relu_layers:
                Number of Linear+ReLU layers to apply.
            c_in:
                Input channel dimension
            n:
                Factor by which c_in is multiplied to obtain the hidden
                channel dimension
            linear_init_params:
                Linear layer initialization parameters
        """
        super().__init__()

        self.c_in = c_in
        self.n = n
        self.num_relu_layers = num_relu_layers

        self.layers = nn.ModuleList(
            [
                nn.Sequential(
                    Linear(self.c_in, self.n * self.c_in, **linear_init_params.layers),
                    nn.ReLU(),
                )
                for _ in range(self.num_relu_layers)
            ]
        )

        self.linear_out = Linear(self.n * self.c_in, self.c_in, init="final")

    def forward(self, x: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x:
                [*, N, C_in] Input tensor
            mask:
                [*, N] Tensor mask
        Returns:
            x:
                [*, N, C_in] Tensor update
        """
        for layer in self.layers:
            x = layer(x)

        x = self.linear_out(x) * mask
        return x


class ReLUTransition(Transition):
    """
    Feed-forward network applied after attention.

    Implements AF2 Algorithms 9 and 15.
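
    Example (minimal sketch; the sizes below are illustrative placeholders,
    not config defaults):
        >>> transition = ReLUTransition(c_in=128, n=4)
        >>> m = torch.randn(1, 8, 256, 128)   # e.g. an MSA-shaped activation
        >>> out = transition(m)               # [1, 8, 256, 128]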
    """

    def __init__(self, c_in, n, linear_init_params=lin_init.relu_transition_init):
        """
        Args:
            c_in:
                Input channel dimension
            n:
                Factor by which c_in is multiplied to obtain the hidden
                channel dimension
            linear_init_params:
                Linear layer initialization parameters
        """
        super().__init__()

        self.c_in = c_in
        self.n = n

        self.layer_norm = LayerNorm(self.c_in)
        self.transition_mlp = ReLUTransitionLayer(
            num_relu_layers=1,
            c_in=self.c_in,
            n=self.n,
            linear_init_params=linear_init_params,
        )

    def _transition(self, x, mask):
        x = self.layer_norm(x)
        x = self.transition_mlp(x=x, mask=mask)
        return x


class SwiGLUTransition(Transition):
    """Feed-forward network applied after attention.

    Implements AF3 Algorithm 11.
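
    Example (minimal sketch; the sizes below are illustrative placeholders,
    not config defaults):
        >>> transition = SwiGLUTransition(c_in=128, n=4)
        >>> z = torch.randn(1, 64, 64, 128)   # e.g. a pair-shaped activation
        >>> out = transition(z)               # [1, 64, 64, 128]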
    """

    def __init__(
        self,
        c_in: int,
        n: int,
        linear_init_params: ConfigDict = lin_init.swiglu_transition_init,
    ):
        """
        Args:
            c_in:
                Input channel dimension
            n:
                Factor by which c_in is multiplied to obtain hidden channel
                dimension
            linear_init_params:
                Linear layer initialization parameters
        """
        super().__init__()

        self.c_in = c_in
        self.n = n

        self.layer_norm = LayerNorm(self.c_in)
        self.swiglu = SwiGLU(
            self.c_in, self.n * self.c_in, linear_init_params=linear_init_params.swiglu
        )
        self.linear_out = Linear(
            self.n * self.c_in, c_in, **linear_init_params.linear_out
        )

    def _transition(self, x: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        # [*, N, C_in]
        x = self.layer_norm(x)

        # [*, N, C_hidden]
        x = self.swiglu(x)

        # [*, N, C_in]
        x = self.linear_out(x)
        x = x * mask

        return x


class ConditionedTransitionBlock(nn.Module):
    """SwiGLU transition block with adaptive layernorm.

    Implements AF3 Algorithm 25.
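
    Example (minimal sketch; the sizes below are illustrative placeholders,
    not config defaults):
        >>> block = ConditionedTransitionBlock(c_a=128, c_s=384, n=2)
        >>> a = torch.randn(1, 256, 128)
        >>> s = torch.randn(1, 256, 384)
        >>> out = block(a, s)  # [1, 256, 128]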
    """

    def __init__(
        self,
        c_a: int,
        c_s: int,
        n: int,
        linear_init_params: ConfigDict = lin_init.cond_transition_init,
    ):
        """

        Args:
            c_in:
                Input channel dimension
            n:
                Factor by which c_in is multiplied to obtain hidden channel
                dimension
            linear_init_params:
                Linear layer initialization parameters
        """
        super().__init__()

        self.c_a = c_a
        self.c_s = c_s
        self.n = n

        self.layer_norm = AdaLN(
            c_a=self.c_a, c_s=self.c_s, linear_init_params=linear_init_params.ada_ln
        )

        self.swiglu = SwiGLU(
            self.c_a, self.n * self.c_a, linear_init_params=linear_init_params.swiglu
        )

        self.sigmoid = nn.Sigmoid()
        self.linear_g = Linear(self.c_s, self.c_a, **linear_init_params.linear_g)
        self.linear_out = Linear(
            self.n * self.c_a, self.c_a, **linear_init_params.linear_out
        )

    def _transition(
        self, a: torch.Tensor, s: torch.Tensor, mask: torch.Tensor
    ) -> torch.Tensor:
        # [*, N, C_a]
        a = self.layer_norm(a, s)

        # [*, N, C_hidden]
        b = self.swiglu(a)

        # AdaLN-zero
        # [*, N, C_a]
        a = self.sigmoid(self.linear_g(s)) * self.linear_out(b)
        a = a * mask

        return a

    @torch.jit.ignore
    def _chunk(
        self,
        a: torch.Tensor,
        s: torch.Tensor,
        mask: torch.Tensor,
        chunk_size: int,
    ) -> torch.Tensor:
        return chunk_layer(
            self._transition,
            {"a": a, "s": s, "mask": mask},
            chunk_size=chunk_size,
            no_batch_dims=len(a.shape[:-2]),
        )

    def forward(
        self,
        a: torch.Tensor,
        s: torch.Tensor,
        mask: torch.Tensor | None = None,
        chunk_size: int | None = None,
    ) -> torch.Tensor:
        """
        Args:
            a:
                [*, N, C_a] Input activation
            s:
                [*, N, C_s] Conditioning tensor used to compute shift/scale
            mask:
                [*, N] Input mask
            chunk_size:
                Inference-time subbatch size

        Returns:
            a:
                [*, N, C_a] Activation update
        """
        if mask is None:
            mask = a.new_ones(a.shape[:-1])

        mask = mask.unsqueeze(-1)

        if chunk_size is not None:
            a = self._chunk(a=a, s=s, mask=mask, chunk_size=chunk_size)
        else:
            a = self._transition(a=a, s=s, mask=mask)

        return a


class StructureModuleTransition(nn.Module):
    """Structure module transition.

    Implements AF2 Algorithm 20 lines 8-9.
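
    Example (minimal sketch; the sizes below are illustrative placeholders,
    not config defaults):
        >>> transition = StructureModuleTransition(
        ...     c=384, num_layers=1, dropout_rate=0.1
        ... )
        >>> s = torch.randn(1, 256, 384)
        >>> out = transition(s)  # [1, 256, 384]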
    """

    def __init__(
        self,
        c,
        num_layers,
        dropout_rate,
        linear_init_params=lin_init.relu_transition_init,
    ):
        """
        Args:
            c: Input channel dimension
            num_layers: Number of ReLUTransitionLayers
            dropout_rate: Dropout rate
            linear_init_params: Linear layer initialization parameters
        """
        super().__init__()

        self.c = c
        self.num_layers = num_layers
        self.dropout_rate = dropout_rate

        self.layers = nn.ModuleList(
            [
                ReLUTransitionLayer(
                    num_relu_layers=2,
                    c_in=self.c,
                    n=1,
                    linear_init_params=linear_init_params,
                )
                for _ in range(self.num_layers)
            ]
        )

        self.dropout = nn.Dropout(self.dropout_rate)
        self.layer_norm = LayerNorm(self.c)

    def forward(self, s, mask=None):
        """
        Args:
            s: [*, N, C] Single representation
            mask: [*, N] Input mask
        Returns:
            [*, N, C] Updated single representation
        """
        if mask is None:
            mask = s.new_ones(s.shape[:-1])

        mask = mask.unsqueeze(-1)

        for layer in self.layers:
            s = s + layer(s, mask)

        s = self.dropout(s)
        s = self.layer_norm(s)

        return s
